from transformers import PretrainedConfig
from typing import List

class LMConfig(PretrainedConfig):
    """Configuration for the language model.

    Holds architecture hyperparameters (dimensions, layer/head counts,
    sequence length, regularization) and forwards any extra keyword
    arguments to ``PretrainedConfig``.
    """

    model_type = "GPT-2.5"

    def __init__(
        self,
        dim: int = 2048,
        n_layers: int = 128,
        n_heads: int = 128,
        n_kv_heads: int = 64,
        dropout: float = 0.2,
        flash_attn: bool = True,
        hidden_dim: int = 2048,
        vocab_size: int = 50257,
        max_seq_len: int = 4096,
        rms_norm_eps: float = 1e-6,
        **kwargs,
    ):
        """Store model hyperparameters.

        Args:
            dim: Model (embedding) dimension.
            n_layers: Number of transformer layers.
            n_heads: Number of attention heads.
            n_kv_heads: Number of key/value heads (grouped-query attention).
            dropout: Dropout probability.
            flash_attn: Whether to use flash attention.
            hidden_dim: Feed-forward hidden dimension.
            vocab_size: Vocabulary size.
            max_seq_len: Maximum sequence length.
            rms_norm_eps: Epsilon for RMSNorm numerical stability.
            **kwargs: Passed through to ``PretrainedConfig``.
        """
        super().__init__(**kwargs)
        # Core architecture sizes.
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        # Attention configuration.
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.flash_attn = flash_attn
        # Tokenization / sequence limits.
        self.vocab_size = vocab_size
        self.max_seq_len = max_seq_len
        # Regularization and normalization.
        self.dropout = dropout
        self.rms_norm_eps = rms_norm_eps